from libauc.optimizers import PESG, Adam, SGD
from libauc.losses import AUCMLoss, CrossEntropyLoss, PSQLoss, PHLoss, PSHLoss, PLLoss, PSMLoss, PBHLoss
from libauc.models import ResNet20
from libauc.datasets import CIFAR10, CIFAR100, STL10, CAT_VS_DOG
from libauc.datasets import ImbalanceGenerator 
from torch_geometric.data import DataLoader as GeoLoader
from gnn import GNN
from loss import pAUC_fai_two, pAUC_fai, pAUC_KL, pAUC_CVaR, pAUC_mini, P_PUSH, SOAPLOSS, pAUC_KL_two, pAUC_mini_two
from SOPALoss import pAUC_CVaR_loss
from SOTAsLoss import tpAUC_KL_loss, tpAUC_CVaR_loss
from SOPA import SOPA
from SOTAs import SOTAs

import torch 
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from imbalanced_sampler import imbalanced_sampler
from sklearn.model_selection import KFold
import tensorflow as tf
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

# Command-line options for the experiment, registered through TensorFlow's
# v1-compatible flags module. The rest of the script reads them via ``FLAGS``.
_flags = tf.compat.v1.flags
FLAGS = _flags.FLAGS

# Data selection and sampling.
_flags.DEFINE_float('imbalanced_ratio', 0.5, 'controled imbalanced ratio for data')
# Objective and optimiser.
_flags.DEFINE_string('loss', 'SOPA-s', 'loss functions for partial AUC, e.g. SOPA-s, SOPA, SOTA-s, OPMini, TPMini, P-PUSH, etc.')
_flags.DEFINE_string('optimizer', 'Adam', 'Adam or Momentum')
_flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100')
_flags.DEFINE_float('lr', 0.001, 'learning rate')
_flags.DEFINE_float('momentum', 0.9, 'momentum for SGD')
_flags.DEFINE_float('decay', 2e-4, 'regularization weight decay')
# Model head and task selection.
_flags.DEFINE_string('activation', 'sigmoid', 'sigmoid, l2 or none')
_flags.DEFINE_integer('class_id', 0, 'class_id for dataset')
_flags.DEFINE_bool('pretrain', True, 'pretrain True or not')
# Loss-specific knobs.
_flags.DEFINE_float('moving_momentum', 0.0, 'momentum for moving average')
_flags.DEFINE_float('margin', 1.0, 'margin parameter for loss')

def pAUC_two_metric(target, pred, max_fpr):
    """Return the ROC AUC computed on the hardest examples of each class.

    From the positives (``target == 1``) the ``max_fpr`` fraction with the
    lowest scores is kept; from the negatives (everything else) the
    ``max_fpr`` fraction with the highest scores is kept. At least one
    example per class is always kept. ``roc_auc_score`` is then evaluated on
    the kept examples only.

    Args:
        target: NumPy array of labels; flattened before use.
        pred: NumPy array of scores with as many elements as ``target``;
            flattened before use.
        max_fpr: Fraction of each class to keep. Values that would select
            more examples than a class contains are clamped to the class size.

    Returns:
        The value returned by ``roc_auc_score`` on the selected subset.

    Raises:
        ValueError: If ``target`` contains no positive or no negative example.
    """
    target = target.reshape(-1)
    pred = pred.reshape(-1)
    idx_pos = np.where(target == 1)[0]
    idx_neg = np.where(target != 1)[0]

    if len(idx_pos) == 0 or len(idx_neg) == 0:
        raise ValueError(
            'pAUC_two_metric needs at least one positive and one negative example'
        )

    # Number of examples to keep per class: at least 1, at most the class size.
    num_pos = min(max(int(round(len(idx_pos) * max_fpr)), 1), len(idx_pos))
    num_neg = min(max(int(round(len(idx_neg) * max_fpr)), 1), len(idx_neg))

    # argpartition(..., kth) guarantees positions [0, kth] hold the kth+1
    # smallest values, so kth = count - 1 selects exactly `count` entries and
    # stays in bounds even when the whole class is kept (kth = count would
    # raise "kth out of bounds" in that case).
    selected_arg_pos = np.argpartition(pred[idx_pos], num_pos - 1)[:num_pos]
    # Negating the scores turns "smallest" into "highest-scored negatives".
    selected_arg_neg = np.argpartition(-pred[idx_neg], num_neg - 1)[:num_neg]

    selected_target = np.concatenate((target[idx_pos][selected_arg_pos], target[idx_neg][selected_arg_neg]))
    selected_pred = np.concatenate((pred[idx_pos][selected_arg_pos], pred[idx_neg][selected_arg_neg]))

    return roc_auc_score(selected_target, selected_pred)

def erm_loss_eval(loss1=0, loss2=0, loss_func='CSQLoss'):
    """Combine two loss terms as ``loss1 + f(loss2)``.

    ``f`` is chosen by ``loss_func``:

    * ``'CSQLoss'``  -> ``loss2 ** 2``
    * ``'CLLoss'``   -> ``log(1 + exp(loss2))``
    * ``'CHLoss'``   -> ``max(loss2, 0)``
    * ``'CSHLoss'`` or ``'AUCMLoss'`` -> ``max(loss2, 0) ** 2``

    Any other ``loss_func`` yields ``None``.
    """
    if loss_func == 'CSQLoss':
        second_term = loss2**2
    elif loss_func == 'CLLoss':
        second_term = np.log(1+np.exp(loss2))
    elif loss_func == 'CHLoss':
        second_term = max(loss2, 0)
    elif loss_func in ('CSHLoss', 'AUCMLoss'):
        second_term = max(loss2, 0)**2
    else:
        # Unrecognised name: nothing to compute.
        return None
    return loss1 + second_term

def set_all_seeds(SEED):
    """Seed the PyTorch and NumPy generators and pin cuDNN's behaviour.

    After the call ``torch.backends.cudnn.deterministic`` is ``True`` and
    ``torch.backends.cudnn.benchmark`` is ``False``.
    """
    # Random number generators.
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # cuDNN switches used for reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

class ImageDataset(Dataset):
    """Dataset over an in-memory image array and a parallel label container.

    Items are ``(transformed_image, target, index)`` triples. In ``'train'``
    mode the image goes through tensor conversion, a random
    ``crop_size`` x ``crop_size`` crop, a random horizontal flip and a resize
    to ``image_size`` x ``image_size``; in any other mode only tensor
    conversion and the resize are applied.
    """

    def __init__(self, images, targets, image_size=32, crop_size=30, mode='train'):
        self.images = images.astype(np.uint8)
        self.targets = targets
        self.mode = mode

        final_size = (image_size, image_size)
        # Both pipelines are always built; `mode` decides which one is used.
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomCrop((crop_size, crop_size), padding=None),
            transforms.RandomHorizontalFlip(),
            transforms.Resize(final_size),
        ])
        self.transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(final_size),
        ])

    def __len__(self):
        """Number of stored images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, target, int(idx))`` for position ``idx``."""
        raw = self.images[idx]
        target = self.targets[idx]
        pil_image = Image.fromarray(raw.astype('uint8'))
        pipeline = self.transform_train if self.mode == 'train' else self.transform_test
        return pipeline(pil_image), target, int(idx)

    def get_labels(self):
        """Return all targets as a flat NumPy array."""
        return np.array(self.targets).reshape(-1)

class ChemicalDataset(Dataset):
    """View of a graph dataset restricted to graphs whose label is not NaN.

    Items are ``(graph, index)`` pairs, where ``index`` is the position inside
    the filtered view. The script builds instances from slices of a
    ``PygGraphPropPredDataset``.

    Attributes set by ``__init__``:
        targets: list of the kept labels as Python scalars.
        ids: tensor of the positions (in the incoming ``datasource``) kept.
        datasource: ``datasource`` indexed by ``ids``.
    """

    def __init__(self, datasource):
        """Filter ``datasource`` and print the fraction of labels equal to 1.

        Args:
            datasource: dataset object exposing ``data.y``, ``len()`` and
                indexing by integer and by index tensor.
        """
        self.targets = []
        self.datasource = datasource
        self.ids = []
        # A label matrix with more than one dimension is reduced to the single
        # column chosen by --class_id. This assigns to ``datasource.data.y``,
        # i.e. it modifies the object that was passed in.
        if len(self.datasource.data.y.shape) > 1:
            self.datasource.data.y = datasource.data.y[:, FLAGS.class_id]
        for i in range(len(datasource)):
            # Index the dataset once per graph and reuse the label tensor.
            label = datasource[i].y
            if not torch.isnan(label[0]):
                self.ids.append(i)
                self.targets.append(label.item())
        self.ids = torch.tensor(self.ids)
        self.datasource = self.datasource[self.ids]

        # Report the class balance. An empty target list is the only case that
        # cannot produce a rate; it is reported instead of being hidden behind
        # a catch-all ``except``.
        labels = np.array(self.targets)
        if len(labels) > 0:
            pos = int(np.count_nonzero(labels == 1))
            print('positive rate: '+ str(float(pos)/len(labels)))
        else:
            print('positive rate error ')

    def __len__(self):
        """Number of graphs kept after filtering."""
        return len(self.targets)

    def __getitem__(self, idx):
        """Return ``(graph, int(idx))`` for position ``idx`` in the filtered view."""
        return (self.datasource[int(idx)], int(idx))

    def get_labels(self):
        """Return the kept labels as a flat NumPy array."""
        return np.array(self.targets).reshape(-1)

class mapper():
    """Assigns consecutive integer slots to ids, separately for ``y = 0`` and ``y = 1``.

    ``idxmap[y]`` maps an id to its slot and ``counts[y]`` is the next free
    slot for that ``y``. Slots are handed out in order of first appearance and
    are stable across calls.
    """

    def __init__(self):
        self.idxmap = [{},{}]
        self.counts = [0,0]

    def transfer(self, ids, y):
        """Translate a tensor of ids into their slots for group ``y``.

        Ids not seen before under ``y`` are registered on the fly. Returns an
        integer NumPy array with the same shape as ``ids``.
        """
        raw_ids = ids.numpy()
        table = self.idxmap[y]
        slots = np.full(raw_ids.shape, -1, dtype=int)
        for position, key in enumerate(raw_ids):
            if key not in table:
                # First sighting: take the next free slot for this group.
                table[key] = self.counts[y]
                self.counts[y] += 1
            slots[position] = table[key]
        return slots



# parameters
# Fixed settings.
SEED, BATCH_SIZE = 123, 64
# NOTE(review): `gamma` and `margin` are not read by name anywhere in the
# visible script (only same-named keyword arguments appear) -- confirm before
# relying on them.
gamma = 500
margin = 1.0

# Settings taken from the command line.
imratio = FLAGS.imbalanced_ratio
lr = FLAGS.lr
weight_decay = FLAGS.decay

set_all_seeds(SEED)

# dataloader
val_sample_ratio = 1.0

# Image benchmarks: loader callable per dataset name.
_image_loaders = {
    'cifar10': CIFAR10,
    'cifar100': CIFAR100,
    'stl10': STL10,
    'catvsdog': CAT_VS_DOG,
}
# Extra ImageDataset keyword arguments as (train kwargs, test kwargs).
# Empty dicts mean "use ImageDataset's defaults".
_image_geometry = {
    'cifar10': ({}, {}),
    'cifar100': ({}, {}),
    'stl10': ({'image_size': 96, 'crop_size': 90}, {'image_size': 96, 'crop_size': 96}),
    'catvsdog': ({'image_size': 50, 'crop_size': 47}, {'image_size': 50, 'crop_size': 50}),
}

if FLAGS.dataset in _image_loaders:
    (train_data, train_label), (test_data, test_label) = _image_loaders[FLAGS.dataset]()
    (train_images, train_labels) = ImbalanceGenerator(train_data, train_label, imratio=0.1, shuffle=True, random_seed=SEED)
    (test_images, test_labels) = ImbalanceGenerator(test_data, test_label, is_balanced=True,  random_seed=SEED)
    _train_kwargs, _test_kwargs = _image_geometry[FLAGS.dataset]
    traindSet = ImageDataset(train_images, train_labels, **_train_kwargs)
    testSet = ImageDataset(test_images, test_labels, mode='test', **_test_kwargs)
else:
    # Any other name is handed to OGB as a graph property prediction dataset.
    dataset = PygGraphPropPredDataset(name = FLAGS.dataset)
    split_idx = dataset.get_idx_split()
    # The official train and valid splits are merged here.
    traindSet = ChemicalDataset(dataset[torch.cat((split_idx['train'],split_idx['valid']))])
    testSet = ChemicalDataset(dataset[split_idx['test']])
    print('train num: '+ str(len(traindSet)))
    print('test num: '+ str(len(testSet)))
    conf = {}
    # Checkpoint name: <dataset>[_<class_id>]_GIN_ce.ckpt; the class id is
    # left out for ogbg-molhiv and ogbg-molbbbp.
    _name_parts = [FLAGS.dataset, 'GIN', 'ce.ckpt']
    if FLAGS.dataset not in ['ogbg-molhiv','ogbg-molbbbp']:
        _name_parts.insert(1, str(FLAGS.class_id))
    conf['pre_train'] = '../pretrained_models/' +  '_'.join(_name_parts)

# You need to include sigmoid activation in the last layer for any customized models!
kf = KFold(n_splits=5)
# Placeholder array with one row per training sample, used only to drive
# the KFold index generator.
tmpX = np.zeros((len(traindSet),1))

# Candidate values tried for the loss-specific hyper-parameter, keyed by
# --loss. Losses not listed here fall back to a grid chosen by --activation.
_loss_parameter_grids = {
    'SOPA': [0.1, 0.3, 0.5],
    'OPMini': [0.1, 0.3, 0.5],
    'SONX': [0.1, 0.3, 0.5],
    'TPMini': [0.3, 0.4, 0.5],
    'SOPA-s': [0.1, 1.0, 10.0],
    'SOTA-s': [0.1, 1.0, 10.0],
    'P-PUSH': [2, 4, 6],
    'OPFAI-p': [101,34,11],
    'TPFAI-p': [101,34,11],
    'OPFAI-e': [30, 20, 8],
    'TPFAI-e': [30, 20, 8],
    'TOPPUSH': [0],
    'AUCMLoss': [100, 500, 1000],
    'PSHLoss': [1.0],
}

if FLAGS.loss in _loss_parameter_grids:
    parameter_set = _loss_parameter_grids[FLAGS.loss]
elif FLAGS.activation in ('sigmoid', 'l2'):
    parameter_set = [0.1, 0.5, 1.0]
else:
    parameter_set = [0.1, 1, 10]



# Test loader: graph datasets go through the torch_geometric loader, the image
# benchmarks through the plain PyTorch one.
if FLAGS.dataset not in ['cifar10', 'cifar100', 'stl10', 'catvsdog']:
    testloader =  GeoLoader(testSet, batch_size=BATCH_SIZE, shuffle=False)
else:
    testloader =  torch.utils.data.DataLoader(testSet, batch_size=32, num_workers=2, shuffle=False)

# Counter of finished cross-validation folds.
part = 0
# Positive / negative counts per batch implied by the imbalance ratio.
pos_num = round(imratio*BATCH_SIZE)
neg_num = BATCH_SIZE - pos_num

print ('Start Training')
print ('-'*30)

# Datasets handled by the image pipeline (ResNet20 + torch DataLoader); every
# other dataset name goes through the graph pipeline (GNN + GeoLoader).
_IMAGE_DATASETS = ('cifar10', 'cifar100', 'stl10', 'catvsdog')
# Losses whose call also receives the mapped positive / negative sample ids.
_INDEXED_LOSSES = ('SOPA-s', 'SOPA', 'P-PUSH', 'SOTA-s', 'SONX', 'TOPPUSH')
# max_fpr levels reported for the one-way and the two-way partial AUC.
_ONE_WAY_LEVELS = (0.1, 0.3, 0.5)
_TWO_WAY_LEVELS = (0.3, 0.4, 0.5)
_EVAL_LOG_FORMAT = 'Epoch=%s, BatchID=%s, Tr_AUC(0.1)=%.4f, TP_Tr_AUC(0.3)=%.4f, Val_AUC(0.1)=%.4f, TP_Val_AUC(0.3)=%.4f, Test_AUC(0.1)=%.4f, TP_Test_AUC(0.3)=%.4f, Tr_AUC(0.3)=%.4f, TP_Tr_AUC(0.4)=%.4f, Val_AUC(0.3)=%.4f, TP_Val_AUC(0.4)=%.4f, Test_AUC(0.3)=%.4f, TP_Test_AUC(0.4)=%.4f, Tr_AUC(0.5)=%.4f, TP_Tr_AUC(0.5)=%.4f, Val_AUC(0.5)=%.4f, TP_Val_AUC(0.5)=%.4f, Test_AUC(0.5)=%.4f, TP_Test_AUC(0.5)=%.4f, lr=%.4f \n'


def _collect_predictions(model, loader, is_image):
    """Run ``model`` over ``loader`` and return ``(labels, scores)`` as NumPy arrays.

    Image batches are ``(inputs, labels, ids)`` triples; graph batches are
    ``(graph_batch, ids)`` pairs whose labels are read from ``graph_batch.y``.
    The caller is responsible for ``model.eval()`` / ``torch.no_grad()``.
    """
    all_scores = []
    all_labels = []
    for batch in loader:
        if is_image:
            inputs, labels, _ = batch
            inputs = inputs.cuda()
        else:
            inputs, _ = batch
            # Labels are taken before the batch is moved to the GPU, exactly
            # as in the original per-split loops, so `.numpy()` below is
            # called on the tensor obtained from the loader.
            labels = inputs.y
            inputs = inputs.cuda()
        scores = model(inputs)
        all_scores.append(scores.cpu().detach().numpy())
        all_labels.append(labels.numpy())
    return np.concatenate(all_labels), np.concatenate(all_scores)


def _partial_auc_scores(y_true, y_pred):
    """Return ``(one_way, two_way)`` partial-AUC lists at the reported levels."""
    one_way = [roc_auc_score(y_true, y_pred, max_fpr=level) for level in _ONE_WAY_LEVELS]
    two_way = [pAUC_two_metric(y_true, y_pred, max_fpr=level) for level in _TWO_WAY_LEVELS]
    return one_way, two_way


# The pipeline choice does not change during the run, so decide it once. The
# original evaluation loops also listed 'melanoma' for test/validation, but no
# other part of the script handles that name; one list is used everywhere.
_is_image = FLAGS.dataset in _IMAGE_DATASETS

for train_id, val_id in kf.split(tmpX):
    # One id -> slot mapping per fold, shared by all hyper-parameter values.
    mapfunc = mapper()
    for para in parameter_set:
        # ---- data loaders and a fresh model for this (fold, para) run ----
        if _is_image:
            tr_sampler = imbalanced_sampler(data_source=traindSet, imratio=imratio, idx=train_id, shuffle=False)
            trainloader = torch.utils.data.DataLoader(dataset=traindSet, sampler=tr_sampler, batch_size=BATCH_SIZE, num_workers=1, shuffle=False)
            validloader = torch.utils.data.DataLoader(dataset=traindSet, sampler=imbalanced_sampler(data_source=traindSet, idx=val_id, sample_scale=val_sample_ratio), batch_size=BATCH_SIZE, num_workers=1, shuffle=False)
            model = ResNet20(pretrained=False, num_classes=1, last_activation=FLAGS.activation)
            model = model.cuda()
        else:
            tr_sampler = imbalanced_sampler(data_source=traindSet, imratio=imratio, idx=train_id, shuffle=True)
            trainloader = GeoLoader(dataset=traindSet, sampler=tr_sampler, batch_size=BATCH_SIZE, num_workers=0, shuffle=False)
            validloader = GeoLoader(dataset=traindSet, sampler=imbalanced_sampler(data_source=traindSet, idx=val_id, sample_scale=val_sample_ratio), batch_size=BATCH_SIZE, num_workers=0, shuffle=False)
            model = GNN(gnn_type = 'gin', num_tasks = 1, num_layer = 5, emb_dim = 64, drop_ratio = 0.5, virtual_node = False, last_activation=FLAGS.activation).cuda()
            if FLAGS.pretrain == True:
                # Load the pretrained weights, then re-initialise the
                # prediction layer.
                model.load_state_dict(torch.load(conf['pre_train']), strict=False)
                model.graph_pred_linear.reset_parameters()
        pos_length = tr_sampler.get_pos_len()
        neg_length = tr_sampler.get_neg_len()

        # ---- loss (and, for SOPA / SOTA-s, its dedicated optimizer) ----
        if FLAGS.loss == 'SOPA-s':
            Loss = pAUC_KL(pos_length=pos_length, Lambda=para)
        elif FLAGS.loss == 'SONX':
            Loss = tpAUC_CVaR_loss(data_length=pos_length, threshold=FLAGS.margin, rate=para, momentum=FLAGS.moving_momentum)
        elif FLAGS.loss == 'SOTA-s':
            Loss = tpAUC_KL_loss(pos_length=pos_length, threshold=FLAGS.margin, Lambda=para, tau=para)
            optimizer = SOTAs(model, loss=Loss, lr=lr, weight_decay=weight_decay)
        elif FLAGS.loss == 'SOPA':
            Loss = pAUC_CVaR_loss(pos_length=pos_length, num_neg=neg_num, beta=para)
            optimizer = SOPA(model, loss=Loss, lr=lr, eta=0.9, weight_decay=weight_decay) # beta_0 -> eta_2
        elif FLAGS.loss == 'TOPPUSH':
            Loss = pAUC_CVaR_loss(pos_length=pos_length, num_neg=neg_num, gamma=para, toppush=True)
        elif FLAGS.loss == 'OPMini':
            Loss = pAUC_mini(num_neg=neg_num, gamma=para)
        elif FLAGS.loss == 'TPMini':
            Loss = pAUC_mini_two(num_pos=pos_num, num_neg=neg_num, gamma=para)
        elif FLAGS.loss == 'P-PUSH':
            Loss = P_PUSH(pos_length=pos_length, poly=para)
        elif FLAGS.loss == 'AUCMLoss':
            Loss = AUCMLoss(margin=1.0)
        elif FLAGS.loss == 'OPFAI-p':
            Loss = pAUC_fai(gamma=para, p_type='poly')
        elif FLAGS.loss == 'OPFAI-e':
            Loss = pAUC_fai(gamma=para, p_type='exp')
        elif FLAGS.loss == 'TPFAI-p':
            Loss = pAUC_fai_two(gamma=para, p_type='poly')
        elif FLAGS.loss == 'TPFAI-e':
            Loss = pAUC_fai_two(gamma=para, p_type='exp')
        elif FLAGS.loss == 'PSHLoss':
            Loss = PSHLoss(margin=para)
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError('unsupported --loss value: %r' % (FLAGS.loss,))

        # ---- optimizer for the remaining losses ----
        if FLAGS.loss == 'AUCMLoss':
            # AUCMLoss is always paired with PESG, whatever --optimizer says.
            optimizer = PESG(model,
                             a=Loss.a,
                             b=Loss.b,
                             alpha=Loss.alpha,
                             imratio=imratio,
                             lr=lr,
                             gamma=para,
                             margin=1.0,
                             weight_decay=weight_decay)
        elif FLAGS.loss not in ['SOPA', 'SOTA-s']:
            if FLAGS.optimizer == 'Adam':
                optimizer = Adam(model, lr=lr, weight_decay=weight_decay)
            elif FLAGS.optimizer == 'Momentum':
                optimizer = SGD(model, lr=lr, weight_decay=weight_decay, momentum=FLAGS.momentum)
            else:
                # Previously `optimizer` stayed undefined (or stale) here.
                raise ValueError('unsupported --optimizer value: %r' % (FLAGS.optimizer,))

        print('Margin=%s, part=%s'%(para, part))
        for epoch in range(61):
            tr_loss = 0
            tr_loss_1 = 0
            tr_loss_2 = 0
            # Step-size decay by a factor of 10 at epochs 21 and 41.
            if epoch==21 or epoch==41:
                optimizer.update_stepsize(decay_factor=10)
            idx = 0
            if epoch > 0: # eval the model at epoch 0 without training
                for idx, data in enumerate(trainloader):
                    if _is_image:
                        batch_inputs, batch_labels, ids = data
                        batch_inputs, batch_labels = batch_inputs.cuda(), batch_labels.cuda()
                        y_pred = model(batch_inputs)
                    else:
                        data, ids = data
                        data = data.cuda()
                        batch_labels = data.y
                        y_pred = model(data)
                    if FLAGS.loss in _INDEXED_LOSSES:
                        # The first pos_num ids of the batch are mapped as
                        # positives, the rest as negatives.
                        loss = Loss(y_pred, batch_labels, mapfunc.transfer(ids[:pos_num], 1), mapfunc.transfer(ids[pos_num:],0))
                    elif FLAGS.loss in ['AUCMLoss']:
                        loss, real_loss_1, real_loss_2 = Loss(y_pred, batch_labels)
                    else:
                        loss = Loss(y_pred, batch_labels)
                    if FLAGS.loss in ['AUCMLoss']:
                        tr_loss_1 = tr_loss_1  + real_loss_1.cpu().detach().numpy()
                        tr_loss_2 = tr_loss_2  + real_loss_2.cpu().detach().numpy()
                    else:
                        tr_loss = tr_loss  + loss.cpu().detach().numpy()
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Average over the idx+1 batches seen in this epoch.
                if FLAGS.loss in ['AUCMLoss']:
                    tr_loss_1 = tr_loss_1/(idx+1)
                    tr_loss_2 = tr_loss_2/(idx+1)
                    tr_loss = erm_loss_eval(tr_loss_1, tr_loss_2, FLAGS.loss)
                else:
                    tr_loss = tr_loss/(idx+1)
                print ('Epoch=%s, BatchID=%s, training_loss=%.4f, lr=%.4f'%(epoch, idx, tr_loss,  optimizer.lr))

            ############### evaluation block ################
            model.eval()
            with torch.no_grad():
                train_true, train_pred = _collect_predictions(model, trainloader, _is_image)
                tr_one_way, tr_two_way = _partial_auc_scores(train_true, train_pred)

                test_true, test_pred = _collect_predictions(model, testloader, _is_image)
                te_one_way, te_two_way = _partial_auc_scores(test_true, test_pred)

                val_true, val_pred = _collect_predictions(model, validloader, _is_image)
                val_one_way, val_two_way = _partial_auc_scores(val_true, val_pred)

                # Log order per level: train, validation, test, each as
                # (one-way, two-way), matching the format string.
                metric_values = []
                for level in range(3):
                    for one_way, two_way in ((tr_one_way, tr_two_way),
                                             (val_one_way, val_two_way),
                                             (te_one_way, te_two_way)):
                        metric_values.append(one_way[level])
                        metric_values.append(two_way[level])
                print(_EVAL_LOG_FORMAT % (epoch, idx, *metric_values, optimizer.lr))
            model.train()
            ############### evaluation end ################
    part += 1

